{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [13:42:42] Enabling RDKit 2019.09.1 jupyter extensions\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fbb761d30d0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch_geometric.data import DataLoader\n",
    "import torch_geometric\n",
    "import torch.distributions as D\n",
    "import matplotlib.pyplot as plt\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms\n",
    "from rdkit.Chem.rdMolAlign import AlignMol\n",
    "from rdkit.Chem import PandasTools\n",
    "from rdkit import rdBase\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import deepdock\n",
    "from deepdock.utils.distributions import *\n",
    "from deepdock.utils.data import *\n",
    "from deepdock.models import *\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "# One seed value shared by the NumPy and PyTorch (CPU and CUDA) generators.\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "# Kept last: its return value (the torch Generator) is what the cell displays.\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "# To force CPU inference, uncomment the next line:\n",
    "# device = 'cpu'\n",
    "\n",
    "# Both encoders use the same depth and dropout; only the first positional argument differs.\n",
    "encoder_kwargs = dict(residual_layers=10, dropout_rate=0.10)\n",
    "ligand_model = LigandNet(28, **encoder_kwargs)\n",
    "target_model = TargetNet(4, **encoder_kwargs)\n",
    "\n",
    "# NOTE: 'dist_threhold' spelling is kept as-is; presumably it matches the DeepDock constructor keyword.\n",
    "model = DeepDock(ligand_model, target_model, hidden_dim=64, n_gaussians=10,\n",
    "                 dropout_rate=0.10, dist_threhold=7.)\n",
    "model = model.to(device)\n",
    "\n",
    "# Restore the trained weights, remapping tensors onto whichever device was selected above.\n",
    "checkpoint_file = deepdock.__path__[0]+'/../Trained_models/DeepDock_pdbbindv2019_13K_minTestLoss.chk'\n",
    "checkpoint = torch.load(checkpoint_file, map_location=torch.device(device))\n",
    "model.load_state_dict(checkpoint['model_state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complexes from pdbBind: 285\n"
     ]
    }
   ],
   "source": [
    "# CASF-2016 archive, located one level above the deepdock package directory.\n",
    "casf_archive = deepdock.__path__[0]+'/../data/dataset_CASF-2016_285.tar'\n",
    "\n",
    "# Both node-count limits are passed as None (presumably this disables size filtering -- confirm in deepdock.utils.data).\n",
    "db_complex = PDBbind_complex_dataset(data_path=casf_archive,\n",
    "                                     min_target_nodes=None,\n",
    "                                     max_ligand_nodes=None)\n",
    "print('Complexes from pdbBind:', len(db_complex))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6.46 s, sys: 333 ms, total: 6.79 s\n",
      "Wall time: 4.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from torch_scatter import scatter_add\n",
    "\n",
    "# Distance cut-offs, visited from largest to smallest: the in-place masking below is\n",
    "# cumulative, which is safe because every smaller cut-off zeroes a superset of the larger one.\n",
    "DIST_CUTOFFS = (10, 7, 5, 3)\n",
    "\n",
    "# One array per batch; each row is [pdbid, score_3, score_5, score_7, score_10, score_all].\n",
    "results = []\n",
    "\n",
    "model.eval()\n",
    "loader = DataLoader(db_complex, batch_size=20, shuffle=False)\n",
    "\n",
    "# Inference only: no_grad() stops PyTorch from recording an autograd graph for every batch.\n",
    "with torch.no_grad():\n",
    "    for data in loader:\n",
    "        ligand, target, _, pdbid = data\n",
    "        ligand, target = ligand.to(device), target.to(device)\n",
    "        pi, sigma, mu, dist, atom_types, bond_types, batch = model(ligand, target)\n",
    "\n",
    "        # Per-row likelihood: exp(log N(dist | mu, sigma) + log pi), summed over dim 1\n",
    "        # (presumably the Gaussian mixture components).\n",
    "        normal = Normal(mu, sigma)\n",
    "        logprob = normal.log_prob(dist.expand_as(normal.loc))\n",
    "        logprob += torch.log(pi)\n",
    "        prob = logprob.exp().sum(1)\n",
    "\n",
    "        # Number of groups in this batch; computed once and reused for every reduction.\n",
    "        n_groups = batch.unique().size(0)\n",
    "        prob_all = scatter_add(prob, batch, dim=0, dim_size=n_groups)\n",
    "\n",
    "        # Zero out rows whose distance exceeds the cut-off, then re-sum per group.\n",
    "        masked_scores = []\n",
    "        for cutoff in DIST_CUTOFFS:\n",
    "            prob[torch.where(dist > cutoff)[0]] = 0.\n",
    "            masked_scores.append(scatter_add(prob, batch, dim=0, dim_size=n_groups))\n",
    "\n",
    "        # Reverse so the columns read 3, 5, 7, 10, then the unmasked total.\n",
    "        prob = torch.stack(masked_scores[::-1] + [prob_all], dim=1)\n",
    "        results.append(np.concatenate([np.expand_dims(pdbid, axis=1),\n",
    "                                       prob.cpu().detach().numpy()], axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(285, 6)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PDB_ID</th>\n",
       "      <th>Score_3A</th>\n",
       "      <th>Score_5A</th>\n",
       "      <th>Score_7A</th>\n",
       "      <th>Score_10A</th>\n",
       "      <th>Score_all</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>4k18</td>\n",
       "      <td>64.25368680000449</td>\n",
       "      <td>420.5898705181299</td>\n",
       "      <td>1376.762661813891</td>\n",
       "      <td>1410.0162568367755</td>\n",
       "      <td>1410.01638602799</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4qac</td>\n",
       "      <td>83.2844061298711</td>\n",
       "      <td>450.01256968999587</td>\n",
       "      <td>1128.0662764096696</td>\n",
       "      <td>1154.7673063010277</td>\n",
       "      <td>1154.7690904503804</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>187.7132053065723</td>\n",
       "      <td>778.5596122182768</td>\n",
       "      <td>1779.2439875640941</td>\n",
       "      <td>1823.1001001165655</td>\n",
       "      <td>1823.110781221914</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4ih7</td>\n",
       "      <td>35.630378250691514</td>\n",
       "      <td>270.11676199153</td>\n",
       "      <td>787.7759051254867</td>\n",
       "      <td>814.7537476029596</td>\n",
       "      <td>814.755215144911</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3dx1</td>\n",
       "      <td>49.94476787861658</td>\n",
       "      <td>202.24570100387868</td>\n",
       "      <td>449.72902901247784</td>\n",
       "      <td>462.38353806385936</td>\n",
       "      <td>462.39191325604116</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  PDB_ID            Score_3A            Score_5A            Score_7A  \\\n",
       "0   4k18   64.25368680000449   420.5898705181299   1376.762661813891   \n",
       "1   4qac    83.2844061298711  450.01256968999587  1128.0662764096696   \n",
       "2   1o3f   187.7132053065723   778.5596122182768  1779.2439875640941   \n",
       "3   4ih7  35.630378250691514     270.11676199153   787.7759051254867   \n",
       "4   3dx1   49.94476787861658  202.24570100387868  449.72902901247784   \n",
       "\n",
       "            Score_10A           Score_all  \n",
       "0  1410.0162568367755    1410.01638602799  \n",
       "1  1154.7673063010277  1154.7690904503804  \n",
       "2  1823.1001001165655   1823.110781221914  \n",
       "3   814.7537476029596    814.755215144911  \n",
       "4  462.38353806385936  462.39191325604116  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SCORE_COLUMNS = ['Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']\n",
    "\n",
    "# `results` (the per-batch list built by the scoring loop) is left untouched,\n",
    "# so this cell can be re-run without first re-running the scoring loop.\n",
    "score_rows = np.concatenate(results, axis=0)\n",
    "scores_df = pd.DataFrame(score_rows, columns=['PDB_ID'] + SCORE_COLUMNS)\n",
    "\n",
    "# The PDB IDs and the scores were concatenated into one NumPy array, which turns\n",
    "# every value into a string; cast the score columns back to floats.\n",
    "scores_df[SCORE_COLUMNS] = scores_df[SCORE_COLUMNS].astype(float)\n",
    "\n",
    "scores_df.to_csv('Score_CoreSet_docking_CASF2016.csv', index=False)\n",
    "print(scores_df.shape)\n",
    "scores_df.head()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "SCORES_DIR = 'ScoringPower_Deepdock/scores'\n",
    "CUTOFF_LABELS = ['3A', '5A', '7A', '10A', 'all']\n",
    "\n",
    "df = pd.read_csv('Score_CoreSet_docking_CASF2016.csv')\n",
    "\n",
    "# DataFrame.to_csv does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(SCORES_DIR, exist_ok=True)\n",
    "\n",
    "# One tab-separated file per cut-off, with the two columns '#code' and 'score'.\n",
    "for label in CUTOFF_LABELS:\n",
    "    score_col = 'Score_' + label\n",
    "    df1 = df[['PDB_ID', score_col]].rename(columns={'PDB_ID': '#code', score_col: 'score'})\n",
    "    df1.to_csv(SCORES_DIR + '/Deepdock_' + label + '.dat', index=False, sep='\\t')"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# Run these from inside ScoringPower_Deepdock/ so that ./scores/ resolves to the .dat files written by the export step above.\n",
    "python /data/CASF-2016/power_scoring/scoring_power.py -c /data/CASF-2016/power_scoring/CoreSet.dat -s ./scores/Deepdock_all.dat -p 'positive' -o 'Deepdock' > ScoringPower_Deepdock_all.out\n",
    "\n",
    "python /data/CASF-2016/power_ranking/ranking_power.py -c /data/CASF-2016/power_ranking/CoreSet.dat -s ./scores/Deepdock_all.dat -p 'positive' -o 'Deepdock' > RankingPower_Deepdock_all.out\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
